use std::{
    collections::HashMap,
    net::SocketAddr,
    ptr,
    sync::{atomic::AtomicBool, Arc},
    time::Duration,
};

use common::base::id_generator;
use tokio::{
    io::{AsyncReadExt, AsyncWriteExt, BufReader},
    net::{tcp::OwnedReadHalf, TcpListener, TcpStream},
    sync::{oneshot::Sender, RwLock},
    time::sleep,
};
use tracing::{error, info, warn};

use super::{net::SocketEventHandler, session::Session};

/// TCP transport: accepts inbound connections on `listen_ip:listen_port`,
/// opens outbound connections via `connect_to`, and tracks every live
/// connection as a `Session` keyed by its session id.
pub struct TCPNet {
    /// Port the accept loop binds to (set by `init`, used by `start`).
    pub listen_port: u16,
    /// IP address the accept loop binds to (set by `init`, used by `start`).
    listen_ip: String,
    /// Set by `stop` to ask the accept loop in `start` to finish.
    is_close: AtomicBool,
    /// Set by `start` once its accept loop has exited; polled by `stop`.
    is_exit: AtomicBool,
    /// Receives connection lifecycle callbacks (accept/connect/receive/
    /// disconnect) and supplies the maximum accepted frame size.
    socket_event_handler: Arc<Box<dyn SocketEventHandler>>,
    /// Live sessions keyed by session id; each session has its own lock so
    /// writes to one connection do not require the map lock.
    session_map: RwLock<HashMap<i64, Arc<RwLock<Box<Session>>>>>,
}

impl TCPNet {
    pub fn new(socket_event_handler: Box<dyn SocketEventHandler>) -> Self {
        TCPNet {
            listen_port: 0,
            listen_ip: String::default(),
            is_close: AtomicBool::new(false),
            is_exit: AtomicBool::new(false),
            socket_event_handler: Arc::new(socket_event_handler),
            session_map: RwLock::new(HashMap::new()),
        }
    }
    /**
     * 关闭 session
     */
    async fn remove_session(&self, session_id: i64) {
        let mut session_map = self.session_map.write().await;
        let session = session_map.remove(&session_id);
        if session.is_none() {
            return;
        }
        let session = session.unwrap();

        let mut session = session.write().await;

        let ret = session.writer.shutdown().await;

        if ret.is_err() {
            warn!("关闭 连接失败！");
        }

        self.socket_event_handler.on_disconnected(session_id).await;
    }

    async fn start_recv(
        reader: OwnedReadHalf,
        session: Arc<RwLock<Box<Session>>>,
        net_ins: Arc<Box<TCPNet>>,
    ) {
        let (ip, port, id) = {
            let session = session.read().await;
            (session.get_ip(), session.get_port(), session.get_id())
        };

        let mut reader = BufReader::new(reader);
        let mut header = [0u8; 4];
        let max_message_size = net_ins.socket_event_handler.max_message_size();
        loop {
            let result = reader.read_exact(&mut header).await;
            if result.is_err() {
                break;
            }
            let msg_size = u32::from_le_bytes(header);
            if msg_size > max_message_size {
                warn!(
                    "ip:{}, port：{}，发送的消息大小为：{}，超过最大值:{} 关闭连接",
                    ip, port, msg_size, max_message_size
                );
                break;
            }

            let message_len = msg_size as usize;

            let mut message: Vec<u8> = Vec::with_capacity(message_len);
            unsafe {
                message.set_len(message_len);

                ptr::copy_nonoverlapping(header.as_ptr(), message.as_mut_ptr(), 4);
            }
            let msg_buffer = &mut message[4..];

            let result = reader.read_exact(msg_buffer).await;
            if result.is_err() {
                break;
            }


            net_ins.socket_event_handler.on_receive(id, message).await;

            // 更新信息
            {
                let mut session = session.write().await;
                session.update_recv_timestamp();
            }
        }
        warn!("关闭 session ：{} ip:{} port : {} 失败", id, ip, port);

        net_ins.remove_session(id).await;
    }

    pub async fn handle_session(
        self: &Arc<Box<TCPNet>>,
        stream: TcpStream,
        addr: SocketAddr,
        is_accept: bool,
    ) -> i64 {
        let (reader, writer) = stream.into_split();

        let session = Arc::new(RwLock::new(Session::new(addr, false, writer)));
        let recv_session = session.clone();

        let session_id = {
            let mut session_write_map = self.session_map.write().await;
            let session_id = {
                let session = session.read().await;
                session.get_id()
            };
            session_write_map.insert(session_id, session);

            session_id
        };

        let clone_net = self.clone();

        if is_accept {
            self.socket_event_handler.on_accept(session_id).await;
        } else {
            self.socket_event_handler.on_connect(session_id).await;
        }

        tokio::spawn(async move {
            Self::start_recv(reader, recv_session, clone_net).await;
        });
        session_id
    }

    pub async fn start(self: Arc<Box<TCPNet>>, statrup_noticer: Sender<bool>) {
        let listen_ip_and_port = format!("{}:{}", self.listen_ip, self.listen_port);
        let listener = TcpListener::bind(listen_ip_and_port).await;
        if listener.is_err() {
            error!(
                "启动 xrpc 服务: {}:{} 失败, 原因:{:?}",
                self.listen_ip, self.listen_port, listener.err().unwrap()
            );
            statrup_noticer.send(false).unwrap();
            return;
        }

        info!("启动 xrpc 服务: {}:{}", self.listen_ip, self.listen_port);
        statrup_noticer.send(true).unwrap();

        let listener = listener.unwrap();
        while !self.is_close.load(std::sync::atomic::Ordering::Relaxed) {
            let socket = listener.accept().await;
            if socket.is_err() {
                warn!("连接出错...");
                continue;
            }
            let (stream, addr) = socket.unwrap();

            self.handle_session(stream, addr, true).await;
        }
        self.is_exit
            .store(true, std::sync::atomic::Ordering::Relaxed);
    }

    /**
     * 初始化
     */
    pub fn init(&mut self, ip: String, port: u16) {
        self.listen_port = port;
        self.listen_ip = ip;
    }

    pub async fn connect_to(self: &Arc<Box<Self>>, conn_id: i64) -> Option<i64> {
        let identity_info = id_generator::get_node_id_info(conn_id);

        let str_ip = format!("{}:{}", identity_info.ip, identity_info.port);

        info!("连接到节点:{} ...", &str_ip);

        let addr = str_ip.parse::<SocketAddr>().unwrap();

        let stream = TcpStream::connect(&addr).await;

        if stream.is_err() {
            return None;
        }

        let stream = stream.unwrap();

        let session_id = self.handle_session(stream, addr, false).await;

        Some(session_id)
    }

    /**
     * 开始
     */

    pub async fn send_message(&self, session_id: i64, message: &Vec<u8>) -> bool {
        let session = {
            let session_map = self.session_map.read().await;
            let session = session_map.get(&session_id);
            if session.is_none() {
                None
            } else {
                Some(session.unwrap().clone())
            }
        };

        if session.is_none() {
            return false;
        }
        let session = session.unwrap();
        let mut session = session.write().await;

        info!("发送消息, 长度：{}", message.len());

        let ret = session.writer.write(&mut message.as_slice()).await;

        if ret.is_err() {
            info!("发送消息失败, 长度：{}", message.len());
            return false;
        }

        let ret = session.writer.flush().await;

        ret.is_ok()
    }

    pub async fn close_session(&self, session_id: i64) {
        self.remove_session(session_id).await;
    }

    /**
     *
     */
    pub async fn get_session_conn_id(&self, session_id: i64) -> Option<i64> {
        let session_map = self.session_map.read().await;

        let session = session_map.get(&session_id);

        if session.is_none() {
            return None;
        }
        let session = session.unwrap();

        let session = session.read().await;

        Some(session.get_conn_info())
    }

    pub async fn get_ip_by_session_id(&self, session_id: i64) -> Option<String> {
        let session_map = self.session_map.read().await;

        let session = session_map.get(&session_id);

        if session.is_none() {
            return None;
        }

        let session = session.unwrap().read().await;

        Some(session.get_ip())
    }

    /**
     * 关闭
     */
    pub async fn stop(&self) {
        self.is_close
            .store(true, std::sync::atomic::Ordering::Relaxed);

        while self.is_exit.load(std::sync::atomic::Ordering::Relaxed) {
            sleep(Duration::from_millis(500)).await;
        }
    }
}
